package org.zjvis.datascience.service;

import java.io.File;
import java.io.FileNotFoundException;
import java.io.InputStream;
import java.util.Arrays;

import javax.annotation.PostConstruct;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Service;

import com.mayabot.nlp.blas.Vector;
import com.mayabot.nlp.fasttext.FastText;
import org.zjvis.datascience.common.util.ToolUtil;

/**
 * Prediction service backed by FastText models.
 *
 * <p>At start-up it loads the "autojoin" model from the working directory. If the file is not
 * there yet, it is first downloaded through {@link MinioService}. The service also offers helpers
 * that embed text with a model and compare two texts by cosine similarity.
 *
 * @date 2021-12-23
 */
@Service
public class FastTextService {

	private static final Logger logger = LoggerFactory.getLogger(FastTextService.class);

	/** Loaded in {@link #init()}; stays {@code null} if loading failed. */
	private FastText autojoinModel;

	@Autowired
	private MinioService minioService;

	/**
	 * Loads the models once the bean has been constructed.
	 *
	 * <p>Failures are logged rather than propagated, so the application still starts without the
	 * model. In that case {@link #getAutojoinModel()} returns {@code null}.
	 */
	@PostConstruct
	private void init() {
		try {
			autojoinModel = loadModel("autojoin");
		} catch (Exception e) {
			logger.error("load fasttext model error...", e);
		}
	}

	/**
	 * Returns the auto-join model.
	 *
	 * @return the model loaded at start-up, or {@code null} if loading failed
	 */
	public FastText getAutojoinModel() {
		return this.autojoinModel;
	}

	/**
	 * Converts a model produced by the C++ fastText tool into a single-file model, optionally
	 * quantizing it first.
	 *
	 * @param modelPath path of the C++ fastText model to read
	 * @param savePath  path the single-file model is written to
	 * @param compress  whether to quantize the model before saving
	 * @throws Exception if reading, quantizing or writing the model fails
	 */
	public void convertModel(String modelPath, String savePath, boolean compress) throws Exception {
		// Named "model" (not "autojoinModel") so it does not shadow the field of that name.
		FastText model = FastText.loadCppModel(new File(modelPath));
		if (compress) {
			// NOTE(review): quantize(2, false, false) arguments kept unchanged; confirm their
			// meaning against the mayabot FastText API.
			model = model.quantize(2, false, false);
		}
		model.saveModelToSingleFile(new File(savePath));
	}

	/**
	 * Loads a single-file model from the working directory, downloading it first if it is missing.
	 *
	 * @param fileName name of the model, used both as the MinIO object name and the local file name
	 * @return the loaded model
	 * @throws Exception if the download or loading fails
	 */
	private FastText loadModel(String fileName) throws Exception {
		String modelPath = System.getProperty("user.dir") + "/" + fileName;
		File modelFile = new File(modelPath);
		if (!modelFile.exists()) {
			logger.info("Fasttext model {} not exist, downloading...", fileName);
			// try-with-resources: the stream returned by MinIO must be closed to release the
			// underlying connection.
			try (InputStream modelStream = minioService.getObject("fasttext-model", fileName)) {
				minioService.writeToLocal(modelPath, modelStream);
			} catch (Exception e) {
				// Remove a partially written file. Otherwise the next start-up would see it,
				// skip the download and try to load a truncated model.
				if (modelFile.exists() && !modelFile.delete()) {
					logger.warn("Failed to delete incomplete fasttext model file {}", modelPath);
				}
				throw e;
			}
		}
		FastText model = FastText.loadModelFromSingleFile(modelFile);
		logger.info("Loaded fasttext model {}", fileName);
		return model;
	}

	/**
	 * Computes the cosine similarity of two texts under the given model.
	 *
	 * <p>Each text is normalised with {@code ToolUtil.standardString(text, ' ')} and then
	 * embedded with {@link #getTextVector(FastText, String)}. Negative similarities are clamped
	 * to 0.
	 *
	 * @param model model used to embed the texts; must not be {@code null}
	 * @param str1  first text
	 * @param str2  second text
	 * @return the non-negative cosine similarity, or 0 if either embedding is a zero vector
	 */
	public double getFasttext(FastText model, String str1, String str2) {
		double[] v1 = getTextVector(model, ToolUtil.standardString(str1, ' '));
		double[] v2 = getTextVector(model, ToolUtil.standardString(str2, ' '));
		return computeVectorSimilarity(v1, v2);
	}

	/**
	 * Embeds a text by splitting it on single spaces and asking the model for the sentence vector.
	 *
	 * @param model model used for the embedding
	 * @param text  text whose tokens are separated by single spaces
	 * @return the sentence vector as a {@code double[]}
	 */
	public double[] getTextVector(FastText model, String text) {
		// NOTE(review): split(" ") yields empty tokens for leading, trailing or repeated spaces.
		// Presumably callers pass text already normalised by ToolUtil.standardString; confirm.
		Vector sentenceVector = model.getSentenceVector(Arrays.asList(text.split(" ")));
		return vectorToArray(sentenceVector);
	}

	/**
	 * Copies a mayabot {@link Vector} into a new primitive array.
	 *
	 * @param v vector to copy
	 * @return an array with the same length and element values as {@code v}
	 */
	public double[] vectorToArray(Vector v) {
		int length = v.length();
		double[] res = new double[length];
		for (int i = 0; i < length; i++) {
			res[i] = v.get(i);
		}
		return res;
	}

	/** Returns the Euclidean (L2) norm of {@code v}. */
	private double normArray(double[] v) {
		double sumOfSquares = 0d;
		for (double x : v) {
			sumOfSquares += x * x;
		}
		return Math.sqrt(sumOfSquares);
	}

	/**
	 * Returns the dot product of two vectors of equal length.
	 *
	 * @throws IllegalArgumentException if the lengths differ
	 */
	private double mulArray(double[] v1, double[] v2) {
		// Without this check a longer v1 overflows v2's bounds, and a shorter v1 silently
		// yields a truncated dot product.
		if (v1.length != v2.length) {
			throw new IllegalArgumentException(
					"Vector lengths differ: " + v1.length + " vs " + v2.length);
		}
		double dot = 0d;
		for (int i = 0; i < v1.length; i++) {
			dot += v1[i] * v2[i];
		}
		return dot;
	}

	/**
	 * Computes the cosine similarity between two vectors, clamped to be non-negative.
	 *
	 * @param v1 first vector
	 * @param v2 second vector, same length as {@code v1}
	 * @return {@code max(cos(v1, v2), 0)}; 0 if either vector has zero norm or the result is NaN
	 * @throws IllegalArgumentException if the vectors have different lengths
	 */
	public double computeVectorSimilarity(double[] v1, double[] v2) {
		double dot = mulArray(v1, v2);
		double den = normArray(v1) * normArray(v2);
		if (den == 0) {
			// A zero vector has no direction, so report no similarity.
			return 0d;
		}
		double cosine = dot / den;
		// The comparison (rather than Math.max) also maps NaN to 0, as before.
		return cosine > 0 ? cosine : 0d;
	}

}
